Allow StatsForecastModel to accept model as string or class#3058
Allow StatsForecastModel to accept model as string or class#3058dennisbader merged 5 commits intounit8co:masterfrom
Conversation
jakubchlapek
left a comment
There was a problem hiding this comment.
Hey @tmchow, thanks a lot for your first contribution :)
The first draft looks nice, but I have left some comments, regarding your changes.
I think that regarding this update we should for the time being keep the option of passing in model instances, while giving an explicit warning to the user that the option is deprecated. The biggest thing I can see currently is that the new str | class paths aren't tested. While all the tests still pass (because we still accept instances so the logic hasn't changed) we don't check the new functionality yet. Since it will be the default I think it would be nice to change the existing tests to leverage the str path instead, while also adding new tests that verify the model class and instance model creation work. Maybe @dennisbader has some other thoughts on how we want to approach the migration?
| @@ -110,8 +118,8 @@ def encode_year(idx): | |||
| >>> series = AirPassengersDataset().load() | |||
There was a problem hiding this comment.
the AutoARIMA import is no longer necessary
| >>> future_cov = datetime_attribute_timeseries(series, "month", cyclic=True, add_length=6) | ||
| >>> # define AutoARIMA parameters | ||
| >>> model = StatsForecastModel(model=AutoARIMA(season_length=12)) | ||
| >>> # define AutoARIMA parameters (using string and model_kwargs) |
There was a problem hiding this comment.
we can get rid of the clarification in () since it will be the default
| if not (isinstance(model_class, type) and issubclass(model_class, _TS)): | ||
| raise_log( | ||
| ValueError( | ||
| f"`{model}` is not a valid StatsForecast model class." | ||
| ), | ||
| logger, | ||
| ) |
There was a problem hiding this comment.
i'd suggest extracting this code into _validate_sf_model_class, as per the NF implementation. while this would work, I'd like to keep the implementations as close to each other as possible :)
| if isinstance(model, _TS): | ||
| # backwards compatibility: model passed as an instance | ||
| if model_kwargs: | ||
| warnings.warn( | ||
| "`model_kwargs` is ignored when `model` is an instance. " | ||
| "Pass `model` as a string or class to use `model_kwargs`.", | ||
| UserWarning, | ||
| stacklevel=2, | ||
| ) | ||
| self.model = model |
There was a problem hiding this comment.
i generally agree that for now the SFModel instances should be accepted at least until some further release, but the error should be more clear this path is now deprecated and will be removed (e.g. "DEPRECATED: <xyz>, not just if the user passes in model_kwargs).
additionally, instead of warnings.warn() use logger.warning() for it.
Adapted the NeuralForecastModel pattern so StatsForecastModel now accepts model as a string name (e.g., "AutoARIMA"), a class, or an instance (backwards compatible). Added model_kwargs parameter for passing constructor arguments when using string or class form. This makes it easier to define models via config files without importing all statsforecast model classes. Fixes unit8co#3055
- Remove unused AutoARIMA import from docstring example - Remove redundant (using string and model_kwargs) clarification - Extract validation into _validate_sf_model_class (matches NF pattern) - Use logger.warning() with DEPRECATED prefix instead of warnings.warn()
especially if it merges an updated upstream into a topic branch.
1d86c95 to
3cb2637
Compare
dennisbader
left a comment
There was a problem hiding this comment.
Thanks @tmchow for this PR and your first contribution 🚀
The PR looked like a good start, I have already applied some updates to make it ready to be merged.
After the tests have passed, I'll merge 💯
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #3058 +/- ##
==========================================
- Coverage 95.79% 95.73% -0.06%
==========================================
Files 158 158
Lines 17303 17315 +12
==========================================
+ Hits 16575 16577 +2
- Misses 728 738 +10 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Summary
StatsForecastModelnow acceptsmodelas a string name, a class, or an instance. This follows the same pattern already used byNeuralForecastModeland makes it much easier to define models via config files or without importing every statsforecast model class.Changes
Updated
darts/models/forecasting/sf_model.py:modelparameter now acceptsstr | type[_TS] | _TS(was_TSonly)model_kwargsparameter for passing constructor arguments when using string or class form_import_sf_model_class()static method that imports models fromstatsforecast.modelsby namemodel_kwargsis passed with an instance, aUserWarningis raisedThree usage patterns now work:
Testing
The change is backwards compatible with existing tests. The string and class resolution paths follow the same pattern as
NeuralForecastModel._import_nf_model_class()which is already tested.Fixes #3055
This contribution was developed with AI assistance (Claude Code).